{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "08845d58-06c5-4484-b8e4-d6171d7ead15",
   "metadata": {},
   "source": [
    "## 1、安装环境\n",
    "本案例基于Python>=3.8，请在您的计算机上安装好Python；  \n",
    "另外，您的计算机上至少要有一张英伟达/昇腾显卡（显存约32GB即可运行）。  \n",
    "我们需要安装以下这几个Python库。在这之前，请确保您的环境内已安装了PyTorch以及CUDA（昇腾显卡对应CANN）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "68f78224-fba3-4015-a647-626030a576fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: modelscope==1.22.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from -r requirements.txt (line 1)) (1.22.0)\n",
      "Requirement already satisfied: transformers==4.51.3 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from -r requirements.txt (line 2)) (4.51.3)\n",
      "Requirement already satisfied: datasets==3.2.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from -r requirements.txt (line 3)) (3.2.0)\n",
      "Requirement already satisfied: accelerate==1.6.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from -r requirements.txt (line 4)) (1.6.0)\n",
      "Requirement already satisfied: peft==0.11.1 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from -r requirements.txt (line 5)) (0.11.1)\n",
      "Requirement already satisfied: pandas in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from -r requirements.txt (line 6)) (2.3.1)\n",
      "Requirement already satisfied: addict in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from -r requirements.txt (line 7)) (2.4.0)\n",
      "Requirement already satisfied: swanlab in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from -r requirements.txt (line 8)) (0.6.10)\n",
      "Requirement already satisfied: requests>=2.25 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from modelscope==1.22.0->-r requirements.txt (line 1)) (2.32.4)\n",
      "Requirement already satisfied: tqdm>=4.64.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from modelscope==1.22.0->-r requirements.txt (line 1)) (4.67.1)\n",
      "Requirement already satisfied: urllib3>=1.26 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from modelscope==1.22.0->-r requirements.txt (line 1)) (2.5.0)\n",
      "Requirement already satisfied: filelock in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from transformers==4.51.3->-r requirements.txt (line 2)) (3.19.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from transformers==4.51.3->-r requirements.txt (line 2)) (0.35.0)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from transformers==4.51.3->-r requirements.txt (line 2)) (1.26.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from transformers==4.51.3->-r requirements.txt (line 2)) (25.0)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from transformers==4.51.3->-r requirements.txt (line 2)) (6.0.2)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from transformers==4.51.3->-r requirements.txt (line 2)) (2024.11.6)\n",
      "Requirement already satisfied: tokenizers<0.22,>=0.21 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from transformers==4.51.3->-r requirements.txt (line 2)) (0.21.0)\n",
      "Requirement already satisfied: safetensors>=0.4.3 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from transformers==4.51.3->-r requirements.txt (line 2)) (0.5.3)\n",
      "Requirement already satisfied: pyarrow>=15.0.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from datasets==3.2.0->-r requirements.txt (line 3)) (19.0.0)\n",
      "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from datasets==3.2.0->-r requirements.txt (line 3)) (0.3.7)\n",
      "Requirement already satisfied: xxhash in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from datasets==3.2.0->-r requirements.txt (line 3)) (3.5.0)\n",
      "Requirement already satisfied: multiprocess<0.70.17 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from datasets==3.2.0->-r requirements.txt (line 3)) (0.70.15)\n",
      "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets==3.2.0->-r requirements.txt (line 3)) (2023.10.0)\n",
      "Requirement already satisfied: aiohttp in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from datasets==3.2.0->-r requirements.txt (line 3)) (3.12.15)\n",
      "Requirement already satisfied: psutil in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from accelerate==1.6.0->-r requirements.txt (line 4)) (5.9.0)\n",
      "Requirement already satisfied: torch>=2.0.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from accelerate==1.6.0->-r requirements.txt (line 4)) (2.4.0)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.30.0->transformers==4.51.3->-r requirements.txt (line 2)) (4.14.1)\n",
      "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.30.0->transformers==4.51.3->-r requirements.txt (line 2)) (1.1.10)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pandas->-r requirements.txt (line 6)) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pandas->-r requirements.txt (line 6)) (2025.2)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pandas->-r requirements.txt (line 6)) (2025.2)\n",
      "Requirement already satisfied: boto3>=1.35.49 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (1.40.35)\n",
      "Requirement already satisfied: botocore in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (1.40.35)\n",
      "Requirement already satisfied: click in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (8.2.1)\n",
      "Requirement already satisfied: platformdirs>=4.2.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (4.3.7)\n",
      "Requirement already satisfied: protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (5.29.3)\n",
      "Requirement already satisfied: pydantic<3 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (2.11.9)\n",
      "Requirement already satisfied: pyecharts>=2.0.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (2.0.8)\n",
      "Requirement already satisfied: pynvml in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (13.0.1)\n",
      "Requirement already satisfied: rich<14.0.0,>=13.6.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (13.9.4)\n",
      "Requirement already satisfied: setuptools in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (78.1.1)\n",
      "Requirement already satisfied: swankit==0.2.4 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (0.2.4)\n",
      "Requirement already satisfied: wrapt>=1.17.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from swanlab->-r requirements.txt (line 8)) (1.17.0)\n",
      "Requirement already satisfied: annotated-types>=0.6.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pydantic<3->swanlab->-r requirements.txt (line 8)) (0.7.0)\n",
      "Requirement already satisfied: pydantic-core==2.33.2 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pydantic<3->swanlab->-r requirements.txt (line 8)) (2.33.2)\n",
      "Requirement already satisfied: typing-inspection>=0.4.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pydantic<3->swanlab->-r requirements.txt (line 8)) (0.4.1)\n",
      "Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from rich<14.0.0,>=13.6.0->swanlab->-r requirements.txt (line 8)) (4.0.0)\n",
      "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from rich<14.0.0,>=13.6.0->swanlab->-r requirements.txt (line 8)) (2.19.1)\n",
      "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from aiohttp->datasets==3.2.0->-r requirements.txt (line 3)) (2.6.1)\n",
      "Requirement already satisfied: aiosignal>=1.4.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from aiohttp->datasets==3.2.0->-r requirements.txt (line 3)) (1.4.0)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from aiohttp->datasets==3.2.0->-r requirements.txt (line 3)) (24.3.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from aiohttp->datasets==3.2.0->-r requirements.txt (line 3)) (1.7.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from aiohttp->datasets==3.2.0->-r requirements.txt (line 3)) (6.6.3)\n",
      "Requirement already satisfied: propcache>=0.2.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from aiohttp->datasets==3.2.0->-r requirements.txt (line 3)) (0.3.2)\n",
      "Requirement already satisfied: yarl<2.0,>=1.17.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from aiohttp->datasets==3.2.0->-r requirements.txt (line 3)) (1.20.1)\n",
      "Requirement already satisfied: idna>=2.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from yarl<2.0,>=1.17.0->aiohttp->datasets==3.2.0->-r requirements.txt (line 3)) (3.7)\n",
      "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from boto3>=1.35.49->swanlab->-r requirements.txt (line 8)) (1.0.1)\n",
      "Requirement already satisfied: s3transfer<0.15.0,>=0.14.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from boto3>=1.35.49->swanlab->-r requirements.txt (line 8)) (0.14.0)\n",
      "Requirement already satisfied: six>=1.5 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->-r requirements.txt (line 6)) (1.17.0)\n",
      "Requirement already satisfied: mdurl~=0.1 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from markdown-it-py>=2.2.0->rich<14.0.0,>=13.6.0->swanlab->-r requirements.txt (line 8)) (0.1.2)\n",
      "Requirement already satisfied: jinja2 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pyecharts>=2.0.0->swanlab->-r requirements.txt (line 8)) (3.1.6)\n",
      "Requirement already satisfied: prettytable in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pyecharts>=2.0.0->swanlab->-r requirements.txt (line 8)) (3.16.0)\n",
      "Requirement already satisfied: simplejson in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pyecharts>=2.0.0->swanlab->-r requirements.txt (line 8)) (3.20.1)\n",
      "Requirement already satisfied: charset_normalizer<4,>=2 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from requests>=2.25->modelscope==1.22.0->-r requirements.txt (line 1)) (3.3.2)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from requests>=2.25->modelscope==1.22.0->-r requirements.txt (line 1)) (2025.8.3)\n",
      "Requirement already satisfied: sympy in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from torch>=2.0.0->accelerate==1.6.0->-r requirements.txt (line 4)) (1.14.0)\n",
      "Requirement already satisfied: networkx in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from torch>=2.0.0->accelerate==1.6.0->-r requirements.txt (line 4)) (3.5)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from jinja2->pyecharts>=2.0.0->swanlab->-r requirements.txt (line 8)) (3.0.2)\n",
      "Requirement already satisfied: wcwidth in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from prettytable->pyecharts>=2.0.0->swanlab->-r requirements.txt (line 8)) (0.2.13)\n",
      "Requirement already satisfied: nvidia-ml-py>=12.0.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from pynvml->swanlab->-r requirements.txt (line 8)) (13.580.82)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/anaconda3/envs/MyPython/lib/python3.12/site-packages (from sympy->torch>=2.0.0->accelerate==1.6.0->-r requirements.txt (line 4)) (1.3.0)\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "# Install the dependencies listed in requirements.txt into the running kernel's environment.\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffac7f79-d009-4f7b-be29-3987465d17b9",
   "metadata": {},
   "source": [
    "## 2、准备数据集\n",
    "本案例使用的是 hk-o1aw-sft-16k 数据集，该数据集主要用于训练法律问答对话模型。  \n",
    "该数据集由14000多条数据组成，每条数据包含prompt、thinking、answer 三列"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "175de1f0-8998-491d-9566-494a3aaa93b1",
   "metadata": {},
   "source": [
    "### 2.1、数据集转换"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c103475-84ee-4a71-a12e-b14687fb6778",
   "metadata": {},
   "source": [
    "这里我们取prompt、thinking、answer 这三列, 分别对应question、think、answer：  \n",
    "question：用户提出的问题，即模型的输入  \n",
    "think：模型的思考过程。大家如果用过DeepSeek R1的话，回复中最开始的思考过程就是这个。  \n",
    "answer：模型思考完成后，回复的内容。  \n",
    "我们的训练任务，便是希望微调后的大模型，能够根据question，给用户一个think+answer的组合回复，并且think和answer在网页上的展示有区分。  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01ea750f-50ac-4474-99d2-1c1a2413f711",
   "metadata": {},
   "source": [
    "理清需求后，我们设计这样一个数据集样例：  \n",
    "```json\n",
    "{\n",
    "    \"question\": \"在交通事故中，如何判断某一行为是否属于‘干扰’事故现场的证据？请讨论法律适用的标准。\",\n",
    "    \"think\": \"**分析法律条文**根据《道路交通条例》第57(1)条，任何人在未得警务人员授权下，移动或干扰事故现场的车辆或证据，即属犯罪。这意味着，...\",\n",
    "    \"answer\": \"判断某行为是否属于‘干扰’事故现场的证据，需考虑行为是否未经授权、是否改变了现场证据，以及是否出于紧急情况...\"\n",
    "}\n",
    "```\n",
    "\n",
    "在训练代码执行时，会将think和answer按下面这样的格式组合成一条完整回复：  \n",
    "```text\n",
    "<think>\n",
    "分析法律条文  根据《道路交通条例》第57(1)条，任何人在未得警务人员授权下，移动或干扰事故现场的车辆或证据，即属犯罪。这意味着，...\n",
    "</think>\n",
    "\n",
    "判断某行为是否属于‘干扰’事故现场的证据，需考虑行为是否未经授权...\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29b79c08-c4d6-40aa-983f-063726229c05",
   "metadata": {},
   "source": [
    "### 2.2、数据集下载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "42b3e438-9c13-436a-ae69-cb25d3dc9ba1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/MyPython/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据集字段： {'prompt': Value(dtype='string', id=None), 'thinking': Value(dtype='string', id=None), 'answer': Value(dtype='string', id=None)}\n",
      "示例数据： {'prompt': '在交通事故中，如何判断某一行为是否属于‘干扰’事故现场的证据？请讨论法律适用的标准。', 'thinking': '\\n**分析法律条文**\\n根据《道路交通条例》第57(1)条，任何人在未得警务人员授权下，移动或干扰事故现场的车辆或证据，即属犯罪。这意味着，判断某行为是否属于‘干扰’的关键在于是否未经授权地改变了事故现场的状态。\\n\\n**考虑例外情况**\\n第57(2)条提供了例外情况，即如果移动或干扰是为了挽救性命、灭火或应付其他紧急事故，则不构成犯罪。这表明，判断行为是否属于‘干扰’还需考虑行为的目的和紧急性。\\n\\n**探索多种方法判断‘干扰’**\\n方法一是通过直接观察行为是否改变了现场证据。方法二是评估行为是否经过警务人员授权。方法三是分析行为是否出于紧急情况的需要。结合这三种方法，可以全面判断某行为是否属于‘干扰’。\\n\\n**验证和排除潜在误解**\\n可能的误解是认为任何移动都构成‘干扰’，但实际上，紧急情况下的移动不构成犯罪。此外，未经授权的移动即使没有明显改变现场，也可能被视为‘干扰’。\\n', 'answer': '判断某行为是否属于‘干扰’事故现场的证据，需考虑行为是否未经授权、是否改变了现场证据，以及是否出于紧急情况。只有综合考虑这些因素，才能准确判断。'}\n",
      "\n",
      "hk-o1aw-sft-16k 数据集分割完成：\n",
      "总数据量：14363 条\n",
      "训练集大小：12926 条\n",
      "验证集大小：1437 条\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Seed the global RNG so the shuffle below, and therefore the train/val split, is reproducible.\n",
    "random.seed(42)\n",
    "\n",
    "# Fraction of the samples written to train.jsonl; the remainder goes to val.jsonl.\n",
    "TRAIN_RATIO = 0.9\n",
    "\n",
    "\n",
    "def to_record(item):\n",
    "    \"\"\"Rename a raw row's prompt/thinking/answer fields to question/think/answer.\"\"\"\n",
    "    return {\n",
    "        \"question\": item[\"prompt\"],\n",
    "        \"think\": item[\"thinking\"],\n",
    "        \"answer\": item[\"answer\"],\n",
    "    }\n",
    "\n",
    "\n",
    "def save_jsonl(items, path):\n",
    "    \"\"\"Write each item, converted by to_record, as one UTF-8 JSON line in `path`.\"\"\"\n",
    "    with open(path, 'w', encoding='utf-8') as f:\n",
    "        for item in items:\n",
    "            # ensure_ascii=False keeps the Chinese text readable instead of \\uXXXX escapes.\n",
    "            json.dump(to_record(item), f, ensure_ascii=False)\n",
    "            f.write('\\n')\n",
    "\n",
    "\n",
    "# 1. Load the \"train\" split of the hk-o1aw-sft-16k dataset.\n",
    "dataset = load_dataset(\"hkair-lab/hk-o1aw-sft-16k\", split=\"train\")\n",
    "\n",
    "# 2. Materialize it as a list so it can be shuffled and sliced.\n",
    "data_list = list(dataset)\n",
    "\n",
    "# 3. Print the schema and the first example to confirm the structure.\n",
    "print(\"数据集字段：\", dataset.features)\n",
    "print(\"示例数据：\", data_list[0])\n",
    "\n",
    "# 4. Shuffle in place (deterministic thanks to the seed above).\n",
    "random.shuffle(data_list)\n",
    "\n",
    "# 5. Split into training and validation sets at TRAIN_RATIO (9:1).\n",
    "split_idx = int(len(data_list) * TRAIN_RATIO)\n",
    "train_data = data_list[:split_idx]\n",
    "val_data = data_list[split_idx:]\n",
    "\n",
    "# 6. Save both splits in the {\"question\", \"think\", \"answer\"} format.\n",
    "save_jsonl(train_data, 'train.jsonl')\n",
    "save_jsonl(val_data, 'val.jsonl')\n",
    "\n",
    "# Report the resulting sizes.\n",
    "print(\"\\nhk-o1aw-sft-16k 数据集分割完成：\")\n",
    "print(f\"总数据量：{len(data_list)} 条\")\n",
    "print(f\"训练集大小：{len(train_data)} 条\")\n",
    "print(f\"验证集大小：{len(val_data)} 条\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e563caed-9f2c-401a-ace1-51e2bd81f540",
   "metadata": {},
   "source": [
    "完成后，你的代码目录下会出现训练集train.jsonl和验证集val.jsonl文件。\n",
    "\n",
    "至此，数据集部分完成。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cfc9deb-997d-4767-b19f-fe3606e117f4",
   "metadata": {},
   "source": [
    "## 3、加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "833f13e3-c989-49eb-9fee-5fde05ac349e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from modelscope import snapshot_download, AutoTokenizer\n",
    "from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq\n",
    "\n",
    "# Download the Qwen3-1.7B model from ModelScope, using the current directory as the cache dir.\n",
    "model_dir = snapshot_download(\"Qwen/Qwen3-1.7B\", cache_dir=\"./\", revision=\"master\")\n",
    "\n",
    "# Load the tokenizer and the weights with Transformers.\n",
    "# torch is imported above because torch.bfloat16 is referenced here.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_dir, device_map=\"auto\", torch_dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9fd188b-9c72-46c8-8194-eda0aeeee3ef",
   "metadata": {},
   "source": [
    "### 3.1、登录SwanLab\n",
    "1. 前往[SwanLab](https://swanlab.cn/space/~/settings)复制你的API Key，在运行下面的代码块时填入（请勿将真实的API Key提交到代码仓库或公开分享）\n",
    "2. 如果你不希望将登录信息保存到该计算机中，可将`save=True`去掉（每次运行训练需要重新执行下面的代码块）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fcfe707b-9ba9-45cc-bcb1-b45595fd40cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/opt/anaconda3/envs/MyPython/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\" for \n",
       "Jupyter support\n",
       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/opt/anaconda3/envs/MyPython/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\" for \n",
       "Jupyter support\n",
       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/MyPython/lib/python3.12/site-packages/swanlab/data/run/metadata/hardware/gpu/nvidia.py:10: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n",
      "  import pynvml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "import swanlab\n",
    "\n",
    "# Never hard-code the API key in the notebook: anyone who can read the file could use it.\n",
    "# Read it from the SWANLAB_API_KEY environment variable, or prompt for it (input hidden) when unset.\n",
    "api_key = os.environ.get(\"SWANLAB_API_KEY\") or getpass(\"SwanLab API Key: \")\n",
    "\n",
    "# save=True stores the login on this machine so later runs do not need to log in again.\n",
    "swanlab.login(api_key=api_key, save=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37c22f22-2319-4029-9c29-2117625beec4",
   "metadata": {},
   "source": [
    "### 3.2、"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
